{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert Distiller Post-Train Quantization Models to \"Native\" PyTorch\n",
    "\n",
    "## Background\n",
    "\n",
    "As of version 1.3 PyTorch comes with built-in quantization functionality. Details are available [here](https://pytorch.org/docs/stable/quantization.html). Distiller's and PyTorch's implementations are completely unrelated. An advantage of PyTorch built-in quantization is that it offers optimized 8-bit execution on CPU and export to GLOW. PyTorch doesn't offer optimized 8-bit execution on GPU (as of version 1.4).\n",
    "\n",
    "At the moment we are still keeping Distiller's separate API and implementation, but we've added the capability to convert a **post-training quantization** model created in Distiller to a \"Distiller-free\" model, comprised entirely of PyTorch built-in quantized modules.\n",
    "\n",
    "Distiller's quantized layers are actually simulated in FP32. Hence, comparing a Distiller model running on CPU to a PyTorch built-in model, the latter will be significantly faster on CPU. However, a Distiller model on a GPU is still likely to be faster compared to a PyTorch model on CPU. So experimenting with Distiller and converting to PyTorch in the end could be useful. Milage may vary of course, depending on the actual HW setup.\n",
    "\n",
    "Let's see how the conversion works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import math\n",
    "import torchnet as tnt\n",
    "from ipywidgets import widgets, interact\n",
    "from copy import deepcopy\n",
    "from collections import OrderedDict\n",
    "\n",
    "import distiller\n",
    "from distiller.models import create_model\n",
    "import distiller.quantization as quant\n",
    "\n",
    "# Load some common code and configure logging\n",
    "# We do this so we can see the logging output coming from\n",
    "# Distiller function calls\n",
    "%run './distiller_jupyter_helpers.ipynb'\n",
    "msglogger = config_notebooks_logger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# By default, the model is moved to the GPU and parallelized (wrapped with torch.nn.DataParallel)\n",
    "# If no GPU is available, a non-parallel model is created on the CPU\n",
    "model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data Loaders\n",
    "\n",
    "We create separate data loaders for GPU and CPU. Set `batch_size` and `num_workers` to optimal values that match your HW setup.\n",
    "\n",
    "(Note we reset the seed before creating each data loader, to make sure both loaders consist of the same subset of the test set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use Distiller's built-in data loading functionality for ImageNet\n",
    "\n",
    "distiller.set_seed(0)\n",
    "\n",
    "subset_size = 1.0 # To save time, can set to value < 1.0\n",
    "dataset = 'imagenet'\n",
    "dataset_path = os.path.expanduser('/datasets/imagenet')\n",
    "arch = 'resnet18'\n",
    "\n",
    "batch_size_gpu = 256\n",
    "num_workers_gpu = 10\n",
    "_, _, test_loader_gpu, _ = distiller.apputils.load_data(\n",
    "    dataset, arch, dataset_path, \n",
    "    batch_size_gpu, num_workers_gpu,\n",
    "    effective_test_size=subset_size, fixed_subset=True, test_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distiller.set_seed(0)\n",
    "batch_size_cpu = 44\n",
    "num_workers_cpu = 10\n",
    "_, _, test_loader_cpu, _ = distiller.apputils.load_data(\n",
    "    dataset, arch, dataset_path, \n",
    "    batch_size_cpu, num_workers_cpu,\n",
    "    effective_test_size=subset_size, fixed_subset=True, test_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Evaluation Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(data_loader, model, device, print_freq=10):\n",
    "    print('Evaluating model')\n",
    "    criterion = torch.nn.CrossEntropyLoss().to(device)\n",
    "    \n",
    "    loss = tnt.meter.AverageValueMeter()\n",
    "    classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))\n",
    "\n",
    "    total_samples = len(data_loader.sampler)\n",
    "    batch_size = data_loader.batch_size\n",
    "    total_steps = math.ceil(total_samples / batch_size)\n",
    "    print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))\n",
    "\n",
    "    # Switch to evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    for step, (inputs, target) in enumerate(data_loader):\n",
    "        with torch.no_grad():\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            # compute output from model\n",
    "            output = model(inputs)\n",
    "\n",
    "            # compute loss and measure accuracy\n",
    "            loss.add(criterion(output, target).item())\n",
    "            classerr.add(output.data, target)\n",
    "            \n",
    "            if (step + 1) % print_freq == 0:\n",
    "                print('[{:3d}/{:3d}] Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
    "                      step + 1, total_steps, classerr.value(1), classerr.value(5), loss.mean), flush=True)\n",
    "    print('----------')\n",
    "    print('Overall ==> Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
    "        classerr.value(1), classerr.value(5), loss.mean), flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-Train Quantize with Distiller"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_mode = {'activations': 'ASYMMETRIC_UNSIGNED', 'weights': 'SYMMETRIC'}\n",
    "stats_file = \"../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml\"\n",
    "dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)\n",
    "\n",
    "quantizer = quant.PostTrainLinearQuantizer(\n",
    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
    "    model_activation_stats=stats_file, overrides=None\n",
    ")\n",
    "quantizer.prepare_model(dummy_input)"
   ]
  },
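  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `quant_mode` setting above selects asymmetric unsigned quantization for activations and symmetric quantization for weights. As a rough illustration of the difference, the sketch below computes a scale and zero-point for each mode from a value range, using the convention of PyTorch quantized tensors, where a real value is approximately `(q - zero_point) * scale`. The function names are hypothetical, and the exact integer ranges and rounding are assumptions of this sketch; Distiller's own implementation may differ in the details.\n",
    "\n",
    "```python\n",
    "def symmetric_params(max_abs, num_bits=8):\n",
    "    # Hypothetical helper. Signed range centered on zero, so the zero-point is 0\n",
    "    # (assumes the restricted range [-127, 127] for 8 bits)\n",
    "    q_max = 2 ** (num_bits - 1) - 1\n",
    "    return max_abs / q_max, 0\n",
    "\n",
    "def asymmetric_unsigned_params(min_val, max_val, num_bits=8):\n",
    "    # Hypothetical helper. Assumes max_val > min_val.\n",
    "    # Extend the range to include 0 so that 0.0 is exactly representable\n",
    "    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)\n",
    "    q_max = 2 ** num_bits - 1\n",
    "    scale = (max_val - min_val) / q_max\n",
    "    # The zero-point is the integer that represents the real value 0.0\n",
    "    zero_point = round(-min_val / scale)\n",
    "    return scale, zero_point\n",
    "\n",
    "# Weights spanning [-1.27, 1.27]\n",
    "print(symmetric_params(1.27))\n",
    "# Activations spanning [-1.0, 3.0]\n",
    "print(asymmetric_unsigned_params(-1.0, 3.0))\n",
    "```\n",
    "\n",
    "An asymmetric range spends no integer codes on values that never occur, which suits activations whose range isn't centered on zero, such as post-ReLU outputs."
   ]
  },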
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to PyTorch Built-In"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we trigger the conversion via the Quantizer instance. Later on we show another way which does not\n",
    "# require the quantizer\n",
    "pyt_model = quantizer.convert_to_pytorch(dummy_input)\n",
    "\n",
    "# Note that the converted model is automatically moved to the CPU, regardless\n",
    "# of the device of the Distiller model\n",
    "print('Distiller model device:', distiller.model_device(quantizer.model))\n",
    "print('PyTorch model device:', distiller.model_device(pyt_model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Evaluation\n",
    "### Distiller Model on GPU (if available)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    %time eval_model(test_loader_gpu, quantizer.model, 'cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distiller Model on CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    print('Creating CPU copy of Distiller model')\n",
    "    cpu_model = distiller.make_non_parallel_copy(quantizer.model).cpu()\n",
    "else:\n",
    "    cpu_model = quantizer.model\n",
    "%time eval_model(test_loader_cpu, cpu_model, 'cpu', print_freq=60)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch model in CPU\n",
    "\n",
    "We expect the PyTorch model on CPU to be much faster than the Distiller model on CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time eval_model(test_loader_cpu, pyt_model, 'cpu', print_freq=60)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For the Extra-Curious: Comparing the Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Distiller takes care of quantizing the inputs within the quantized modules PyTorch quantized modules assume the input is already quantized. Hence, for cases where a module's input is not quantized, we explicitly add a quantization operation for the input. The first layer in the model, `conv1` in ResNet18, is such a case\n",
    "2. Both Distiller and native PyTorch support fused ReLU. In Distiller, this is somewhat obscurely indicated by the `clip_half_range` attribute inside `output_quant_settings`. In PyTorch, the module type is explicitly `QuantizedConvReLU2d`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('conv1\\n')\n",
    "print('DISTILLER:\\n{}\\n'.format(quantizer.model.module.conv1))\n",
    "print('PyTorch:\\n{}\\n'.format(pyt_model.conv1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example of internal layers which don't require explicit input quantization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('layer1.0.conv1')\n",
    "print(pyt_model.layer1[0].conv1)\n",
    "print('\\nlayer1.0.add')\n",
    "print(pyt_model.layer1[0].add)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatic de-quantization <--> quantization in the model\n",
    "\n",
    "For each quantized module in the Distiller implementation, we quantize the input and de-quantize the output.\n",
    "So, if the user explicitly sets \"internal\" modules to run in FP32, this is transparent to the other quantized modules (at the cost of redundant quant-dequant operations).\n",
    "\n",
    "When converting to PyTorch we remove these redundant operations, and keep just the required ones in case the user explicitly decided to run some modules in FP32.\n",
    "\n",
    "For an example, consider a ResNet \"basic block\" with a residual connection that contains a downsampling convolution. Let's see how such a block looks in our fully-quantized, converted model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(pyt_model.layer2[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see all layers are either built-in quantized PyTorch modules, or identity operations representing fused operations. The entire block is quantized, so we don't see any quant-dequnt operations in the middle.\n",
    "\n",
    "Now let's create a new quantized model, and this time leave the 'downsample' module in FP32:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overrides = OrderedDict(\n",
    "    [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n",
    ")\n",
    "new_quantizer = quant.PostTrainLinearQuantizer(\n",
    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
    "    model_activation_stats=stats_file, overrides=overrides\n",
    ")\n",
    "new_quantizer.prepare_model(dummy_input)\n",
    "\n",
    "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n",
    "\n",
    "print(new_pyt_model.layer2[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see a few differences:\n",
    "1. The `downsample` module now contains a de-quantize op before the actual convolution\n",
    "2. The `add` module now contains a quantize op before the actual add. Note that the add operation accepts 2 inputs. In this case the first input (index 0) comes from the `conv2` module, which is quantized. The second input (index 1) comes from the `downsample` module, which we kept in FP32. So, we only need to quantized the input at index 1. We can see this is indeed what is happening, by looking at the `ModuleDict` inside the `quant` module, and noticing it has only a single key for index \"1\".\n",
    "\n",
    "Let's see how the `add` module would look if we also kept the `conv2` module in FP32:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overrides = OrderedDict(\n",
    "    [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)])),\n",
    "     ('layer2.0.conv2', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n",
    ")\n",
    "new_quantizer = quant.PostTrainLinearQuantizer(\n",
    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
    "    model_activation_stats=stats_file, overrides=overrides\n",
    ")\n",
    "new_quantizer.prepare_model(dummy_input)\n",
    "\n",
    "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n",
    "\n",
    "print(new_pyt_model.layer2[0].add)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that now both inputs to the add module are being quantized."
   ]
  },
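  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rule illustrated in this section can be summarized as: insert a quantize or de-quantize op only where the data type changes between consecutive modules. Below is a minimal sketch of that rule for a simple chain of modules. The function `boundary_ops` is hypothetical and ignores multi-input modules such as `add`, where the decision is made per input; it is not the conversion code Distiller uses.\n",
    "\n",
    "```python\n",
    "def boundary_ops(layers):\n",
    "    # Hypothetical helper, for illustration only.\n",
    "    # layers: list of (name, is_quantized) tuples in execution order\n",
    "    ops = []\n",
    "    prev_quantized = False  # the model input arrives in FP32\n",
    "    for name, is_quantized in layers:\n",
    "        if is_quantized and not prev_quantized:\n",
    "            ops.append('quantize')    # FP32 -> quantized boundary\n",
    "        elif not is_quantized and prev_quantized:\n",
    "            ops.append('dequantize')  # quantized -> FP32 boundary\n",
    "        ops.append(name)\n",
    "        prev_quantized = is_quantized\n",
    "    return ops\n",
    "\n",
    "# Fully quantized chain: only the model input needs a quantize op\n",
    "print(boundary_ops([('conv1', True), ('conv2', True)]))\n",
    "# One FP32 module in the middle: ops appear only around that module\n",
    "print(boundary_ops([('conv1', True), ('downsample', False), ('conv2', True)]))\n",
    "```"
   ]
  },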
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Another API for Conversion\n",
    "\n",
    "In some cases we don't have the actual quantizer. For example - if the Distiller quantized module was loaded from a checkpoint. In those cases we can call a `distiller.quantization` module-level function (In fact, the Quantizer method we used earlier is a wrapper around this function).\n",
    "\n",
    "### Save Distiller model to checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save Distiller model to checkpoint and load it\n",
    "distiller.apputils.save_checkpoint(0, 'resnet18', quantizer.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Checkpoint\n",
    "\n",
    "The model is quantized when the checkpoint is loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = create_model(False, dataset='imagenet', arch='resnet18', parallel=True)\n",
    "loaded_model = distiller.apputils.load_lean_checkpoint(loaded_model, 'checkpoint.pth.tar')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert and Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert\n",
    "loaded_pyt_model = distiller.quantization.convert_distiller_ptq_model_to_pytorch(loaded_model, dummy_input)\n",
    "\n",
    "# Run evaluation\n",
    "%time eval_model(test_loader_cpu, loaded_pyt_model, 'cpu', print_freq=60)\n",
    "\n",
    "# Cleanup\n",
    "os.remove('checkpoint.pth.tar')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
